#! /usr/bin/env python3
#
# Copyright 2009-2019 The VOTCA Development Team (http://www.votca.org)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Version string interpolated into the banner text (VOTCAHEADER) below.
# NOTE(review): the @...@ / #...# tokens look like placeholders that are
# substituted at build time -- confirm against the build system.
VERSION='@PROJECT_VERSION@ #CSG_GIT_ID#'
import sys
import os
import numpy as np
import lxml.etree as lxml
import time
import argparse

# Title line inserted into the banner below.
PROGTITLE = 'THE VOTCA::XTP creator aux-basissetfiles'
# One-line description handed to argparse as the parser description.
PROGDESCR = 'Creates votca xml aux-basissetfiles from votca basissetfiles'
# Banner text; XtpHelpFormatter returns it in place of the generated usage line.
VOTCAHEADER = '''\
==================================================
========   VOTCA (http://www.votca.org)   ========
==================================================

{progtitle}

please submit bugs to bugs@votca.org 
xtp_makeauxbasis, version {version}

'''.format(version=VERSION, progtitle=PROGTITLE)


# =============================================================================
# PROGRAM OPTIONS
# =============================================================================

class XtpHelpFormatter(argparse.HelpFormatter):
    """Help formatter whose usage section is the static VOTCA banner."""

    def _format_usage(self, usage, action, group, prefix):
        """Return the VOTCA banner; every argument is deliberately ignored."""
        banner = VOTCAHEADER
        return banner

# Command-line interface. Parsing happens at import time and the resulting
# module-level names (args, outputfile, onlyoneelement) are read by the
# functions further down in this file.
progargs = argparse.ArgumentParser(
    prog='xtp_makeauxbasis',
    # argparse calls formatter_class(prog); the lambda widens the help column.
    formatter_class=lambda prog: XtpHelpFormatter(prog, max_help_position=70),
    description=PROGDESCR)

progargs.add_argument('-f', '--basisfile',
    action='store',
    required=True,
    type=argparse.FileType('r'),
    help='xtp basissetfile')

progargs.add_argument('-o', '--outfile',
    type=str,
    default='',
    help='optional file to write basisset to')

progargs.add_argument('-g', '--grouping',
    required=False,
    type=float,
    default=0.1,
    help='Cutoff at which basisfunctions are grouped together in deviation from the arithmetic mean; Default 0.1')

progargs.add_argument('-c', '--cutoff',
    required=False,
    type=float,
    default=60,
    help='Cutoff for very localised basisfunctions; Default 60')

progargs.add_argument('-e', '--element',
    required=False,
    type=str,
    default="",
    help='Print out only the element specified')

progargs.add_argument('-l', '--lmax',
    required=False,
    type=int,
    default=4,
    help='maximum angular momentum in aux basisset: Default:4')


args = progargs.parse_args()

# Both names hold False when the option was left empty and the option's
# string value otherwise; later code compares them against False.
outputfile = False
if args.outfile != "":
    outputfile = args.outfile
onlyoneelement = False
if args.element != "":
    onlyoneelement = args.element

# The shell-type tables in this file stop at L=6, which they name "I"
# (H is L=5), so the message names I to match the limit actually enforced.
if args.lmax > 6:
    print("ERROR: Higher order functions than I (L=6) are not implemented")
    sys.exit(1)



class basisfunction(object):
    """Plain record holding a decay value and an angular momentum number."""

    def __init__(self, decay, L):
        """Store ``decay`` and ``L`` unchanged and start with an empty ``Ls``."""
        self.decay, self.L = decay, L
        # Fresh list per instance; nothing in this file reads it back.
        self.Ls = list()


class collection(object):
    """Group of basis functions that is written out as a single shell.

    Attributes:
        functions: member objects; each must expose ``decay`` and ``L``.
        com:       arithmetic mean of the members' decays (0 when empty).
        decay:     geometric mean of the members' decays (0 when empty).
        Ls:        sorted distinct ``L`` values, refreshed by ``findLs``.
    """

    def __init__(self):
        self.functions = []
        self.com = 0
        self.decay = 0
        # Initialised here so checkcompleteness()/splitshell() can be called
        # before findLs() has ever run.
        self.Ls = []

    def addbasisfunction(self, function):
        """Append one function and refresh both means."""
        self.functions.append(function)
        self.calcmean()
        self.calcdecay()
        return

    def addcollection(self, collection):
        """Absorb every function of another collection and refresh both means."""
        self.functions += collection.functions
        self.calcmean()
        self.calcdecay()
        return

    def calcmean(self):
        """Set ``com`` to the arithmetic mean of the decays (0 if empty)."""
        if not self.functions:
            self.com = 0
            return
        self.com = sum(f.decay for f in self.functions) / len(self.functions)
        return

    def size(self):
        """Return the number of member functions."""
        return len(self.functions)

    def calcdecay(self):
        """Set ``decay`` to the geometric mean of the decays (0 if empty).

        The plain product is used whenever it is representable, which keeps
        results identical to the straightforward formula. For large groups
        the product can overflow to inf or underflow to 0; in that case the
        mean is evaluated in log space instead.
        """
        count = len(self.functions)
        if count == 0:
            self.decay = 0
            return
        product = 1
        for f in self.functions:
            product = product * f.decay
        degenerate = product == 0 or np.isinf(product)
        if degenerate and all(f.decay > 0 for f in self.functions):
            # Only reached when the product left the float range although
            # every factor is a positive number.
            log_sum = sum(np.log(f.decay) for f in self.functions)
            self.decay = float(np.exp(log_sum / count))
        else:
            self.decay = product ** (1 / float(count))
        return

    def popbasisfunction(self):
        """Remove and return the members for which ``isclose(member, self)`` holds.

        ``com`` and ``decay`` are not refreshed afterwards.
        NOTE(review): each member is still part of ``self.functions`` while it
        is being tested -- confirm that this is the intended criterion.
        """
        keep = []
        pop = []
        for member in self.functions:
            if isclose(member, self):
                pop.append(member)
            else:
                keep.append(member)
        self.functions = keep
        return pop

    def findLs(self):
        """Refresh ``Ls`` with the sorted distinct ``L`` values of the members."""
        self.Ls = sorted(set(f.L for f in self.functions))
        return

    def shellstring(self):
        """Return the shell letters of all contained ``L`` values, ascending."""
        self.findLs()
        return "".join(LtoShelltype(l) for l in self.Ls)

    def checkcompleteness(self):
        """Return True when ``Ls`` is a gap-free run of integers.

        An empty ``Ls`` counts as complete because there is nothing to split.
        """
        if not self.Ls:
            return True
        return len(self.Ls) == (self.Ls[-1] - self.Ls[0] + 1)

    def splitshell(self):
        """Cut the collection at every gap in its ``L`` values.

        This collection keeps the lowest contiguous run; each higher run is
        moved into a new collection. Both means are refreshed afterwards.

        Returns:
            list of the newly created collections, highest run first.
        """
        # Work on current L values instead of whatever findLs() last saw.
        self.findLs()
        newcollections = []
        if not self.checkcompleteness():
            breakpointLs = [low for low, high in zip(self.Ls[:-1], self.Ls[1:])
                            if low + 1 != high]
            # Peel runs off from the top so each pass only has to separate
            # the functions above one break point.
            for point in reversed(breakpointLs):
                keep = []
                newcollection = collection()
                for function in self.functions:
                    if function.L <= point:
                        keep.append(function)
                    else:
                        newcollection.addbasisfunction(function)
                newcollection.findLs()
                self.functions = keep
                newcollections.append(newcollection)
            self.findLs()
        self.calcmean()
        self.calcdecay()
        return newcollections

    def attachshelltoxml(self, child):
        """Append a ``shell`` element describing this collection to ``child``.

        The shell gets one ``constant`` child carrying the decay and one
        ``contractions`` entry per contained ``L`` value.
        """
        # shellstring() also refreshes self.Ls, which the loop below relies on.
        shell = lxml.SubElement(child, "shell", type=self.shellstring(), scale="1.0")
        constant = lxml.SubElement(shell, "constant", decay="{:1.6e}".format(self.decay))
        for l in self.Ls:
            lxml.SubElement(constant, "contractions", type=LtoShelltype(l), factor="1.0")
        

def isclose(function, collection):
    """Tell whether ``function`` lies near any member of ``collection``.

    Nearness is measured as the absolute difference of the natural logs of
    the two decays; the allowed distance is ``args.grouping`` scaled by
    ``1 + log(function.L + 1)``.

    Returns:
        True if at least one member is within that distance, else False.
    """
    return any(
        abs(np.log(function.decay) - np.log(member.decay))
        < args.grouping * (1 + np.log(function.L + 1))
        for member in collection.functions
    )


def isclosecol(col1, col2):
    """Tell whether any function of ``col1`` lies near any function of ``col2``.

    Two functions are near when the absolute difference of the natural logs
    of their decays is below ``args.grouping``.

    Returns:
        True as soon as one such pair is found, False if there is none.
    """
    for f in col1.functions:
        # Invariant for the inner loop, so evaluate it once per f.
        log_f = np.log(f.decay)
        for g in col2.functions:
            if abs(log_f - np.log(g.decay)) < args.grouping:
                # Leave both loops at once; a bare break would only end the
                # inner loop and keep scanning the rest of col1 for nothing.
                return True
    return False

def ShelltypetoL(shelltype):
    """Map a shell letter (S, P, D, F, G, H, I) to its number 0..6.

    Raises:
        KeyError: if ``shelltype`` is not one of the seven letters.
    """
    letter_to_number = dict(zip("SPDFGHI", range(7)))
    return letter_to_number[shelltype]

def LtoShelltype(L):
    """Map a number 0..6 to its shell letter (S, P, D, F, G, H, I).

    Raises:
        KeyError: if ``L`` is outside 0..6.
    """
    number_to_letter = dict(enumerate("SPDFGHI"))
    return number_to_letter[L]

def createauxfunctions(basisfunctions):
    """Build auxiliary functions from every unordered pair of basis functions.

    Each pair (a function paired with itself included) yields a candidate
    whose decay and L are the sums of the pair's decays and Ls. Candidates
    with ``L > args.lmax`` or ``decay > args.cutoff`` are dropped.

    Args:
        basisfunctions: objects exposing ``decay`` and ``L``.

    Returns:
        list of new ``basisfunction`` objects, in pair order (i <= j).
    """
    # Local import: the file's import block does not provide itertools.
    from itertools import combinations_with_replacement

    auxfunctions = []
    # Yields (i, j) pairs with i <= j in the same order as the nested index
    # loops it replaces; those relied on xrange, which Python 3 lacks.
    for first, second in combinations_with_replacement(basisfunctions, 2):
        L = first.L + second.L
        if L > args.lmax:
            continue
        decay = first.decay + second.decay
        if decay > args.cutoff:
            continue
        auxfunctions.append(basisfunction(decay, L))

    print ("{} aux functions created".format(len(auxfunctions)))
    return auxfunctions

def sortauxfunctions(auxfunctions):
    """Bucket the functions by their ``L`` value.

    Prints the highest L found and the size of every bucket.

    Returns:
        list of lists; entry ``i`` holds the functions with ``L == i`` in
        input order. There is always at least the bucket for L=0.
    """
    # Leading 0 reproduces the "start at zero" behaviour for empty input.
    highest = max([0] + [f.L for f in auxfunctions])
    print ("Maximum angular momentum in aux basisset is {}".format(highest))

    buckets = [[] for _ in range(highest + 1)]
    for f in auxfunctions:
        buckets[f.L].append(f)

    for momentum, bucket in enumerate(buckets):
        print ("{} functions have L={}".format(len(bucket), momentum))
    return buckets

def readinbasisset(xmlfile):
    """Parse a basis-set XML and build the matching aux-basis XML tree.

    Every ``element`` node found under the root is converted with
    ``processelement``; when ``onlyoneelement`` is set, only the node whose
    ``name`` attribute equals it is converted.

    Args:
        xmlfile: source handed straight to ``lxml.parse``.

    Returns:
        a new ``basis`` element named ``"aux-" + <root name>``.
    """
    print ("Parsing  {}\n".format(xmlfile))
    tree = lxml.parse(xmlfile, lxml.XMLParser(remove_comments=True))
    root = tree.getroot()
    print ("Basisset name is {}".format(root.get("name")))
    output = lxml.Element("basis", name="aux-" + root.get("name"))
    for element in root.iter('element'):
        # Skip everything but the requested element when a filter is active.
        if onlyoneelement != False and element.get("name") != onlyoneelement:
            continue
        output.append(processelement(element))
    return output

def clusterfunctionsperL(functions):
    """Greedily group functions whose decays are close to each other.

    ``functions`` is sorted in place by descending decay. Each function then
    joins the first existing group that ``isclose`` accepts it for, or opens
    a new group. A summary of the groups is printed.

    Returns:
        list of ``collection`` objects in creation order.
    """
    functions.sort(key=lambda item: item.decay, reverse=True)
    groups = []
    for candidate in functions:
        for group in groups:
            if isclose(candidate, group):
                group.addbasisfunction(candidate)
                break
        else:
            # No group accepted the candidate (this also covers the very
            # first function, when there are no groups yet).
            fresh = collection()
            fresh.addbasisfunction(candidate)
            groups.append(fresh)

    print ("Reduced functions from {} to {}".format(len(functions), len(groups)))
    for index, group in enumerate(groups):
        print ("{} Functions: {} Decay {}".format(index, group.size(), group.decay))
    return groups



def clustercollections(collections):
    """Greedily merge collections whose members are close to each other.

    ``collections`` is sorted in place by descending decay. Each collection
    is then merged into the first already-accepted shell that ``isclosecol``
    matches it with, or becomes a shell of its own. A summary is printed.

    Returns:
        list of the resulting shells in creation order.
    """
    collections.sort(key=lambda item: item.decay, reverse=True)
    merged = []
    for candidate in collections:
        for shell in merged:
            if isclosecol(shell, candidate):
                shell.addcollection(candidate)
                break
        else:
            # Nothing matched (or nothing accepted yet): start a new shell.
            merged.append(candidate)

    print ("Reduced functions from {} to {}".format(len(collections), len(merged)))
    for index, shell in enumerate(merged):
        print ("{} Functions: {} Decay {} Type {}".format(index, shell.size(), shell.decay, shell.shellstring()))
    return merged
            




def processelement(elementxml):
    """Convert one ``element`` node of the basis set into its aux-basis node.

    Steps: read every (decay, shell type) pair, create the pairwise aux
    functions, cluster them per L, merge the clusters across L in two
    passes, split shells with gaps in L, and write the shells to XML.

    Returns:
        a new ``element`` node carrying one ``shell`` child per final shell.
    """
    name = elementxml.get("name")
    print("\nReading element {}".format(name))

    primitives = []
    for shellnode in elementxml.iter('shell'):
        for constantnode in shellnode.iter('constant'):
            # Converted once per constant, before looking at its contractions.
            decay = float(constantnode.get("decay"))
            for contractionnode in constantnode.iter("contractions"):
                momentum = ShelltypetoL(contractionnode.get("type"))
                primitives.append(basisfunction(decay, momentum))

    print("Found {} basisfunctions".format(len(primitives)))
    perL = sortauxfunctions(createauxfunctions(primitives))

    clusters = []
    for L, members in enumerate(perL):
        print ("Clustering L={}".format(L))
        clusters += clusterfunctionsperL(members)

    print ("Collecting shells over Ls")
    firstpass = clustercollections(clusters)
    print ("Collecting shells over Ls twice")
    finalshells = clustercollections(firstpass)

    # Collect the split-off pieces first so the list being iterated does not
    # grow during the loop, then append them behind the original shells.
    pieces = [piece for shell in finalshells for piece in shell.splitshell()]
    finalshells.extend(pieces)

    note = "Created by make-auxbasis.py with Lmax={},c={} and g={}".format(args.lmax, args.cutoff, args.grouping)
    outputxml = lxml.Element("element", name=name, help=note)
    for shell in finalshells:
        shell.attachshelltoxml(outputxml)
    return outputxml
            

        
# Run the conversion and emit the resulting XML, either to stdout or to the
# file named with -o.
outputxml=readinbasisset(args.basisfile)
# Ask lxml for text rather than bytes: under Python 3 the bytes result would
# print as a b'...' literal and could not be written to a text-mode file.
xmltext=lxml.tostring(outputxml, pretty_print=True, encoding="unicode")
if outputfile==False:
    print (xmltext)
else:
    with open(outputfile, 'w') as f:
        f.write(xmltext)
    
    




        
        
    
    
    
